#include <time.h>
#include <cstdio>
#include <iostream>
#include <vector>
#include <mpi.h>
using namespace std;
// Matrix side length. Guarded so it can be overridden at build time
// (e.g. -DSIZE=8 to also print the result matrix).
#ifndef SIZE
#define SIZE 512
#endif

// Computes the dense row-major block c = a * b, where a is rows x n,
// b is n x k and c is rows x k. Loops run i-l-j so the innermost loop walks
// b and c contiguously; every c[i][j] still starts at 0 and accumulates its
// n products in ascending l order.
static void multiply_rows(const double *a, const double *b, double *c,
                          int rows, int n, int k)
{
    for (int i = 0; i < rows; i++)
    {
        double *c_row = c + i * k;
        for (int j = 0; j < k; j++)
        {
            c_row[j] = 0.0;
        }
        for (int l = 0; l < n; l++)
        {
            const double a_il = a[i * n + l];
            const double *b_row = b + l * k;
            for (int j = 0; j < k; j++)
            {
                c_row[j] += a_il * b_row[j];
            }
        }
    }
}

// Distributed dense matrix product C = A * B using MPI.
// Rank 0 fills A (m x n) and B (n x k) with 1-based row-major indices,
// broadcasts B, scatters equal blocks of A's rows, every rank multiplies its
// block, and rank 0 gathers C. Rows left over when m is not a multiple of the
// process count are computed by rank 0 itself. Rank 0 prints C when
// SIZE <= 8 and always prints the elapsed wall time in seconds.
int main()
{
    const int m = SIZE, n = SIZE, k = SIZE; // A: m x n, B: n x k, C: m x k
    int comm_sz; // number of processes
    int my_rank; // rank of this process
    MPI_Init(nullptr, nullptr);
    MPI_Comm_size(MPI_COMM_WORLD, &comm_sz);
    MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);

    // Every rank receives row_range consecutive rows of A. Integer division may
    // leave rows [tail_begin, m) unassigned; rank 0 handles those after the gather.
    const int row_range = m / comm_sz;
    const int tail_begin = row_range * comm_sz;
    const int A_size = row_range * n;
    const int C_size = row_range * k;

    // Flat row-major storage owned by vectors (released automatically).
    // Only rank 0 needs the full A and C: MPI_Scatter's send buffer and
    // MPI_Gather's receive buffer are significant only at the root.
    std::vector<double> B(n * k);
    std::vector<double> A;
    std::vector<double> C;
    std::vector<double> buf_A(A_size);
    std::vector<double> buf_C(C_size);

    if (my_rank == 0)
    {
        A.resize(m * n);
        C.resize(m * k);
        for (int i = 0; i < m; i++)
        {
            for (int j = 0; j < n; j++)
            {
                A[i * n + j] = i * n + j + 1;
            }
        }
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < k; j++)
            {
                B[i * k + j] = i * k + j + 1;
            }
        }
    }

    double start, end;
    MPI_Barrier(MPI_COMM_WORLD);
    start = MPI_Wtime();

    MPI_Bcast(B.data(), n * k, MPI_DOUBLE, 0, MPI_COMM_WORLD);
    // Collective calls already synchronize as needed; no extra barrier here.
    MPI_Scatter(A.data(), A_size, MPI_DOUBLE, buf_A.data(), A_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);

    multiply_rows(buf_A.data(), B.data(), buf_C.data(), row_range, n, k);

    MPI_Gather(buf_C.data(), C_size, MPI_DOUBLE, C.data(), C_size, MPI_DOUBLE, 0, MPI_COMM_WORLD);

    // Rows not covered by the equal-sized blocks (also covers comm_sz > m).
    if (my_rank == 0 && tail_begin < m)
    {
        multiply_rows(A.data() + tail_begin * n, B.data(), C.data() + tail_begin * k,
                      m - tail_begin, n, k);
    }

    MPI_Barrier(MPI_COMM_WORLD);
    end = MPI_Wtime();
    if (my_rank == 0)
    {
        if (SIZE <= 8)
        {
            for (int i = 0; i < m; i++)
            {
                for (int j = 0; j < k; j++)
                {
                    printf("%lf ", C[i * k + j]);
                }
                putchar('\n');
            }
        }
        printf("%lf s\n", end - start);
    }

    MPI_Finalize();
    return 0;
}